#ifndef AHLGREN_PROLOG_TERM
#define AHLGREN_PROLOG_TERM

#include <iostream>
#include <cassert>
#include <cstddef> // std::nullptr_t, std::ptrdiff_t
#include <cstring> // memcmp
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <string>
#include <utility> // std::pair
#include <vector>

#include "useful.h"
#include "global.h"

namespace lp {
	// Forward declarations: term's member functions (begin/end, node_begin/node_end,
	// body_begin/body_end, arg_begin/arg_end) return these iterator types, which are
	// only defined after the term class.
	class node_const_iterator;
	class term_const_iterator;
	class body_iterator;
	class body_const_iterator;
	class arg_iterator;
	class arg_const_iterator;


	// A single cell of a term.
	// - The type tag (type) selects which union member is meaningful:
	//     negative (functor) and LIS -> ar is the arity
	//     CON                        -> ar holds the constant id
	//     INT                        -> ival,  FLT -> dval
	//     VAR / EVAR                 -> ptr is the binding; an unbound variable points to itself
	// - A functor cell is followed in memory by its arguments: argument k (1-based)
	//   is the cell at this+k (this is how the arg iterators address them).
	class term {
	public:
		// Type tag: negative = functor id (its arity is stored in ar), 0 = null/unknown term,
		// small positive values = the kind tags below.
		typedef lp::id_type type_t;
		// Note: numbers reflect reverse sorting order except for FUNC (VAR < FLT < INT < CON)
		static const type_t VAR = 1;
		static const type_t EVAR = 2; // variable pointing to external stuff
		static const type_t FLT = 3;
		static const type_t INT = 4;
		static const type_t CON = 5;
		static const type_t LIS = 6; // list cell, i.e. the functor pair_id/2 (see make_list)
		//=========== Constructors ===========//
		// Construct unknown (null) term; only the type tag is initialized
		term() : type(0) {}
		// Construct term from a type tag and an int payload (arity or constant id)
		term(type_t t, int a) : type(t), ar(a) {}
		// Construct term from a type tag and a pointer payload (variable binding)
		term(type_t t, term* p) : type(t), ptr(p) {}
		// Numeric terms. These constructors are implicit: long long / double convert to term.
		term(long long i) : type(INT), ival(i) {}
		term(double d) : type(FLT), dval(d) {}
		// Named constructors
		static term make_int(long long i) { return term(i); }
		static term make_double(double d) { return term(d); }
		static term make_variable(term* p) { return term(term::VAR,p); }
		// Functor f/a with f < 0 and a > 0. pair_id/2 is rejected: lists must use make_list().
		static term make_function(type_t f, int a) { assert(f < 0 && a > 0 && !(f == pair_id && a == 2)); return term(f,a); }
		static term make_function(type_t f, unsigned a) { return make_function(f,int(a)); }
		// Constant with (negative) id c
		static term make_constant(type_t c) { assert(c < 0); return term(term::CON,c); }
		// List cell: LIS tag with arity 2
		static term make_list() { return term(term::LIS,2); }
		// Raw constructor: type tag v1 with the 64-bit integer payload v2
		term(int v1, long long v2) : type(v1), ival(v2) {}

		// Test for types
		bool is_null() const { return type == 0; }

		bool is_function() const { return type < 0 || type == LIS; }
		// Is this term the functor f/n? List cells carry the LIS tag instead of pair_id
		// (make_function forbids pair_id/2), so they also match pair_id/2, which keeps
		// this query consistent with signature(). Matching on the LIS tag itself still works.
		bool is_function(lp::id_type f, int n) const
		{
			return (type == f && n == ar) || (is_list() && f == pair_id && n == 2);
		}
		bool is_variable() const { return type == VAR || type == EVAR; }
		bool is_list() const { return type == LIS; }
		bool is_constant() const { return type == CON; }
		// Is this the constant with id f?
		bool is_constant(lp::id_type f) const { return is_constant() && get_const() == f; }
		bool is_int() const { return type == INT; }
		bool is_double() const { return type == FLT; }

		// Get signature p/n
		std::pair<lp::id_type,int> signature() const;

		// Get/Set values. No checks are made: the caller must know which payload
		// the type tag implies (they share one union).
		void set_type(type_t t) { type = t; }
		type_t get_type() const { return type; }
		void set_fun(type_t id) { type = id; }
		type_t get_fun() const { return type; }

		void set_ptr(term* p) { ptr = p; }
		term* get_ptr() const { return ptr; }
		void set_arity(int a) { ar = a; }
		int arity() const { return ar; }
		void set_const(int id) { ar = id; }
		int get_const() const { return ar; }
		void set_int(long long v) { ival = v; }
		long long get_int() const { return ival; }
		void set_double(double d) { dval = d; }
		double get_double() const { return dval; }

		// Iterate through all terms
		term_const_iterator begin(bool cycle = false) const;
		term_const_iterator end(bool cycle = false) const;
		// Iterate through all nodes
		node_const_iterator node_begin(bool cycle = false) const;
		node_const_iterator node_end(bool cycle = false) const;
		// Iterate through body literals
		body_iterator body_begin();
		body_iterator body_end();
		body_const_iterator body_begin() const;
		body_const_iterator body_end() const;
		// Iterate through arguments
		arg_iterator arg_begin();
		arg_iterator arg_end();
		arg_const_iterator arg_begin() const;
		arg_const_iterator arg_end() const;
		// Access using position (vector of indices)
		template <typename iter> term* get(iter i, iter end);
		template <typename iter> const term* get(iter i, iter e) const;
		term* get(const std::vector<int>& v) { return get(v.begin(),v.end()); }
		const term* get(const std::vector<int>& v) const { return get(v.begin(),v.end()); }

		// Compute number of functors (nodes)
		int size() const;

		// Clause functions
		term* head();
		inline const term* head() const { return const_cast<term*>(this)->head(); }
		term* body();
		inline const term* body() const { return const_cast<term*>(this)->body(); }
		int body_size() const;
		body_iterator body_erase(body_iterator);
		void body_push_back(term* t);

		// Print term
		void print(std::ostream& os) const;

#ifndef NDEBUG
		// Debug builds route term allocations through class-specific operators
		static void* operator new(std::size_t);
		static void operator delete(void*);
		static void* operator new[](std::size_t size);
		static void operator delete[](void*);
#endif
	protected:
		type_t type;
		// Payload; which member is valid depends on type (see the class comment)
		union {
			int ar;
			long long ival;
			double dval;
			term* ptr;
		};
	};

	// Cycle check for a term. NOTE(review): the meaning of the returned pointer
	// (e.g. which cell closes the cycle, nullptr if none) is defined out of line -- confirm.
	const term* is_cyclic(const term* t);

	//===== Printing =====//
	// Print in infix notation; vname is the table of printed variable names
	// (presumably variable cell -> name; verify against the definition).
	void print_infix(std::ostream& os, const term* t, std::map<const term*,std::string>& vname);
	// Convenience overload: starts from an empty variable-name table.
	inline void print_infix(std::ostream& os, const term* t)
	{
		std::map<const term*,std::string> names;
		print_infix(os, t, names);
	}
	void print_infix_r(std::ostream& os, const term* t, std::map<const term*,std::string>& vname, std::set<const term*>& singletons, std::set<const term*>& funs);
	// Print addresses in memory
	void print_all(std::ostream& os, const term*);
	// Default stream output: infix notation with an empty variable-name table.
	inline std::ostream& operator<<(std::ostream& os, const term& t)
	{
		const term* const root = &t;
		print_infix(os, root);
		return os;
	}

	//===== Dereferencing =====//
	term* deref(term* t);
	const term* deref(const term* t);
	// Versions that track visited variables in vars (guarding against cyclic bindings)
	term* deref_nocycle(term* t, std::set<const term*>& vars);
	inline const term* deref_nocycle(const term* t, std::set<const term*>& vars)
	{
		term* const mutable_t = const_cast<term*>(t);
		return deref_nocycle(mutable_t, vars);
	}

	// Make a deep copy of a term expression
	//term* copy(const term* t, std::map<const term*,term*>&);
	//inline term* copy(const term* t) { std::map<const term*,term*> m; return copy(t,m); }


	// Make Term from predefined id
	term* make_term(lp::id_type id);
	// Recursively compare two terms (ISO Prolog)
	int rcompare(const term* x, const term* y);
	// Variant check; v1/v2 are the variable correspondence tables built while comparing
	bool is_variant(const term* x, const term* y, std::map<const term*,const term*>& v1, std::map<const term*,const term*>& v2);
	inline bool is_variant(const term* x, const term* y)
	{
		std::map<const term*,const term*> v1;
		std::map<const term*,const term*> v2;
		return is_variant(x, y, v1, v2);
	}
	// Evaluate and Numerically compare two terms
	int num_compare(const term* x, const term* y);
	// Hash terms
	//struct term_ptr_hash {
	//	std::size_t operator()(const term* t) const;
	//};

	// Evaluate numeric (throws non_numeric otherwise)
	term eval(const term* t, int max_recursion);
	// Evaluate with the default recursion limit.
	inline term eval(const term* t)
	{
		const int default_max_recursion = 10000; // TODO: make recursion limit adjustable parameter
		return eval(t, default_max_recursion);
	}
	// Is f a subset of g?
	//bool is_subset(const term* f, const term* g, bool ordered = true);
	// Is term ground?
	//bool is_ground(const term* t, std::set<const term*>& fun);
	//inline bool is_ground(const term* t) { std::set<const term*> s; return is_ground(t,s); }

	//template <typename Fun>
	//bool apply_noncyclic(term* t, Fun fun, std::set<const term*>& vars)
	//{
	//	switch (t->get_type()) {
	//	case term::VAR: case term::EVAR:
	//		if (t->get_ptr() != t) {
	//			auto p = vars.insert(t);
	//			if (!p.second) {
	//				// Variable already used, we have a cycle
	//				if (!fun(nullptr)) return false;
	//				return true;
	//			}
	//			if (!apply_noncyclic(t->get_ptr(),fun,vars)) return false;
	//			vars.erase(t);
	//		} else {
	//			if (!fun(t)) return false;
	//		}
	//		break;
	//	case term::CON: case term::INT: case term::FLT:
	//		if (!fun(t)) return false;
	//		break;
	//	default:
	//		assert(t->is_function());
	//		if (!fun(t)) return false;
	//		for (int k = 1; k <= t->arity(); ++k) {
	//			if (!apply_noncyclic(t+k,fun,vars)) return false;
	//		}
	//	}
	//	return true;
	//}

	//template <typename Fun>
	//bool apply_noncyclic(term* t, Fun fun)
	//{
	//	std::set<const term*> vars;
	//	return apply_noncyclic(t,fun,vars);
	//}

	//template <typename Fun>
	//bool apply_noncyclic(const term* t, Fun fun)
	//{
	//	std::set<const term*> vars;
	//	return apply_noncyclic(const_cast<term*>(t),fun,vars);
	//}


	// Iterate through body literals
	class body_const_iterator : public std::iterator<std::forward_iterator_tag,term> {
	public:
		body_const_iterator() : ptr(nullptr), at_end() {}
		body_const_iterator(const term* r);
		body_const_iterator& operator++();
		body_const_iterator operator++(int) { auto i = *this; ++*this; return i; }
		bool operator==(const body_const_iterator& i) const { return ptr == i.ptr; }
		bool operator!=(const body_const_iterator& i) const { return !(*this == i); }
		const term& operator*() const { return *deref(ptr); }
		const term* operator->() const { return deref(ptr); }
		const term* noderef_ptr() const { return ptr; }
	protected:
		const term* ptr;
		bool at_end;
		body_const_iterator(const term* r, bool e) : ptr(r), at_end(e) {}
	};

	class body_iterator : public body_const_iterator {
	public:
		body_iterator() : body_const_iterator() {}
		body_iterator(term* r) : body_const_iterator(r) {}
		body_iterator& operator++();
		body_iterator operator++(int) { auto i = *this; ++*this; return i; }
		bool operator==(const body_iterator& i) const { return ptr == i.ptr; }
		bool operator!=(const body_iterator& i) const { return !(*this == i); }
		term& operator*() { return const_cast<term&>(*deref(ptr)); }
		term* operator->() { return const_cast<term*>(deref(ptr)); }
		term* noderef_ptr() { return const_cast<term*>(ptr); }
		operator body_const_iterator() { return body_const_iterator(ptr,at_end); }
	};

	// Iterate through arguments
	class arg_const_iterator : public std::iterator<std::bidirectional_iterator_tag,term> {
	public:
		arg_const_iterator(const term* r, int s = 1) : root(r), i(s) {}
		arg_const_iterator& operator++() { ++i; return *this; }
		arg_const_iterator operator++(int) { auto s = *this; ++*this; return s; }
		arg_const_iterator& operator--() { --i; return *this; }
		arg_const_iterator operator--(int) { auto s = *this; --*this; return s; }
		bool operator==(const arg_const_iterator& s) const { return root == s.root && i == s.i; }
		bool operator!=(const arg_const_iterator& s) const { return !(*this == s); }
		const term& operator*() const { return *deref(root+i); }
		const term* operator->() const { return deref(root+i); }
		const term* noderef_ptr() const { return root+i; }
	protected:
		const term* root;
		int i;
	};

	class arg_iterator : public arg_const_iterator {
	public:
		arg_iterator(term* r, int s = 1) : arg_const_iterator(r,s) {}
		arg_iterator& operator++() { ++i; return *this; }
		arg_iterator operator++(int) { auto s = *this; ++*this; return s; }
		arg_iterator& operator--() { --i; return *this; }
		arg_iterator operator--(int) { auto s = *this; --*this; return s; }
		bool operator==(const arg_iterator& s) const { return root == s.root && i == s.i; }
		bool operator!=(const arg_iterator& s) const { return !(*this == s); }
		term& operator*() { return const_cast<term&>(*deref(root+i)); }
		term* operator->() { return const_cast<term*>(deref(root+i)); }
		term* noderef_ptr() { return const_cast<term*>(root)+i; }
	};


	// Iterates over the nodes (cells) reachable from a term; the increment logic is
	// defined out of line. An end iterator has nptr == nullptr, and equality compares
	// nptr only.
	// NOTE(review): pars/vars look like parent links and visited variables used for
	// traversal and cycle handling -- confirm against operator++.
	class node_const_iterator {
	public:
		typedef std::forward_iterator_tag iterator_category;
		typedef term value_type;
		typedef std::ptrdiff_t difference_type;
		typedef const term* pointer;
		typedef const term& reference;

		// Default-constructed iterator: compares equal to the end iterator.
		node_const_iterator() : nptr(nullptr), accept_cycle(false) { }
		// Begin iterator positioned at fp; cycle is stored in accept_cycle.
		node_const_iterator(const term* fp, bool cycle) : nptr(fp), accept_cycle(cycle) { }
		// End iterator. accept_cycle is initialized so that copying an end iterator
		// never reads an indeterminate bool.
		node_const_iterator(std::nullptr_t) : nptr(nullptr), accept_cycle(false) { }

		// Prefix increment
		node_const_iterator& operator++();

		// Postfix increment
		node_const_iterator operator++(int) {
			node_const_iterator tmp(*this);
			++(*this);
			return tmp;
		}

		bool operator==(const node_const_iterator& di) const { return nptr == di.nptr; }
		bool operator!=(const node_const_iterator& di) const { return !(*this == di); }
		reference operator*() const { return *nptr;  }
		pointer operator->() const { return nptr; }
	protected:
		pointer nptr;
		bool accept_cycle;
		std::map<const term*,const term*> pars;
		std::set<const term*> vars;
	};


	class term_const_iterator : public std::iterator<std::forward_iterator_tag,term> {
	public:
		typedef std::forward_iterator_tag iterator_category;
		typedef term value_type;
		typedef const term* pointer;
		typedef const term& reference;

		// For begin: fp == rp == node
		term_const_iterator(const term* fp, bool cycle) : ni(fp,cycle) { }
		// for end: fp == nullptr, rp == root
		term_const_iterator(nullptr_t, bool cycle) : ni(nullptr,cycle) { }

		// Prefix increment
		term_const_iterator& operator++() {
			do {
				++ni;
			} while (ni != node_const_iterator(nullptr) && ni->is_variable() && &*ni != ni->get_ptr());
			return *this;
		}

		// Postfix increment
		term_const_iterator operator++(int) {
			term_const_iterator tmp(*this);
			++(*this);
			return tmp;
		}

		bool operator==(const term_const_iterator& di) const { return ni == di.ni; }
		bool operator!=(const term_const_iterator& di) const { return !(*this == di); }
		reference operator*() const { return *ni;  }
		pointer operator->() const { return &*ni; }
	protected:
		node_const_iterator ni;
	};

}


namespace lp {
	// Whole-term and node iteration: begin iterators start at this cell, end
	// iterators are built from nullptr (the cycle flag is forwarded either way).
	inline term_const_iterator term::begin(bool cycle) const
	{
		return term_const_iterator(this, cycle);
	}
	inline term_const_iterator term::end(bool cycle) const
	{
		return term_const_iterator(nullptr, cycle);
	}
	inline node_const_iterator term::node_begin(bool cycle) const
	{
		return node_const_iterator(this, cycle);
	}
	inline node_const_iterator term::node_end(bool cycle) const
	{
		return node_const_iterator(nullptr, cycle);
	}
	// Argument iteration: positions 1..arity() of a functor; non-functors yield an
	// empty range (begin and end are both built from nullptr).
	inline arg_iterator term::arg_begin()
	{
		if (!is_function()) return arg_iterator(nullptr);
		return arg_iterator(this);
	}
	inline arg_iterator term::arg_end()
	{
		if (!is_function()) return arg_iterator(nullptr);
		return arg_iterator(this, arity() + 1);
	}
	inline arg_const_iterator term::arg_begin() const
	{
		if (!is_function()) return arg_const_iterator(nullptr);
		return arg_const_iterator(this);
	}
	inline arg_const_iterator term::arg_end() const
	{
		if (!is_function()) return arg_const_iterator(nullptr);
		return arg_const_iterator(this, arity() + 1);
	}

	// Number of positions visited by term_const_iterator from begin() to end().
	inline int term::size() const
	{
		const auto count = std::distance(begin(), end());
		return static_cast<int>(count);
	}
	// Number of body literals visited by body_const_iterator.
	inline int term::body_size() const
	{
		const auto count = std::distance(body_begin(), body_end());
		return static_cast<int>(count);
	}

	// Dereferencing
	// Follow variable bindings from t until reaching either a non-variable cell or an
	// unbound variable (a variable whose ptr refers to itself), and return that cell.
	// There is no cycle check: a cyclic binding chain makes this loop forever
	// (deref_nocycle takes a visited-variable set instead).
	inline term* deref(term* t)
	{
		for (;;) {
			if (!t->is_variable()) return t;
			term* const target = t->get_ptr();
			if (target == t) return t;
			t = target;
		}
	}
	// Const overload: same traversal via the non-const version.
	inline const term* deref(const term* t)
	{
		term* const mutable_t = const_cast<term*>(t);
		return deref(mutable_t);
	}


	inline std::pair<lp::id_type,int> term::signature() const
	{
		typedef std::pair<lp::id_type,int> sign_t;
		switch (type) {
		case VAR: return sign_t(0,int(VAR));
		case EVAR: return sign_t(0,int(EVAR));
		case 0: return sign_t(0,0);
		case INT: return sign_t(0,int(INT));
		case FLT: return sign_t(0,int(FLT));
		case CON: return sign_t(get_const(),0);
		case LIS: return sign_t(pair_id,2);
		default: 
			assert(is_function());
			return sign_t(get_type(),arity());
		}
	}

}


#ifndef NDEBUG
namespace lp {
	// Debug-only hook, declared under the same NDEBUG guard as term's class-specific
	// operator new/delete overloads.
	// NOTE(review): presumably reports term allocations that were never freed -- confirm
	// against its out-of-line definition.
	void check_memory();
}
#endif


#endif


